package plus.easydo.starter.oauth.server.service;

import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.oauth2.common.exceptions.InvalidGrantException;
import org.springframework.security.oauth2.common.util.RandomValueStringGenerator;
import org.springframework.security.oauth2.common.util.SerializationUtils;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.code.AuthorizationCodeServices;
import plus.easydo.starter.oauth.server.model.OauthCode;
import plus.easydo.starter.oauth.server.properties.Oauth2ServerProperties;
import plus.easydo.starter.oauth.server.serializer.FastJson2RedisTokenStoreSerializationStrategy;
import plus.easydo.starter.redis.RedisUtils;

import javax.annotation.Resource;
import java.util.concurrent.TimeUnit;

/**
 * {@link AuthorizationCodeServices} implementation with a custom authorization-code flow
 * that stores codes in Redis instead of a database.
 *
 * <p>Two Redis entries are kept per issued code, both with the same time-to-live:
 * <ul>
 *   <li>{@code oauthCodePrefix + code} holds the {@link OauthCode} record (client id, code and
 *       the serialized {@link OAuth2Authentication});</li>
 *   <li>{@code codeClientPrefix + clientId + ":" + userName} holds the code itself, so that a
 *       repeated authorization by the same user for the same client reuses the pending code
 *       instead of creating a new Redis entry every time.</li>
 * </ul>
 *
 * <p>The pending-code entry is scoped to the authenticated user on purpose: a pending code is
 * bound to the authentication it was created for, so it must never be handed to a different
 * user of the same client.
 *
 * @author yuzhanfeng
 */
@EnableConfigurationProperties(Oauth2ServerProperties.class)
@Configuration
public class CustomizeAuthorizationCodeServices implements AuthorizationCodeServices {

    /**
     * Lifetime, in minutes, of a stored authorization code and of its pending-code entry.
     */
    private static final int CODE_TTL_MINUTES = 5;

    /**
     * Generator for the authorization code; configured for 12-character codes.
     */
    private final RandomValueStringGenerator randomValueStringGenerator = new RandomValueStringGenerator(12);

    /**
     * FastJson based serializer, used when {@code enableFastJsonSerializer} is switched on.
     */
    private final FastJson2RedisTokenStoreSerializationStrategy serializations
            = new FastJson2RedisTokenStoreSerializationStrategy<>(OAuth2Authentication.class);

    @Resource
    Oauth2ServerProperties oAuth2Properties;
    /**
     * Redis is used as the code store in place of a database.
     */
    @Resource
    RedisUtils<Object> redisUtils;

    /**
     * Creates (or reuses) an authorization code for the given authentication.
     *
     * @param oAuth2Authentication the authentication the code is issued for
     * @return the authorization code
     */
    @Override
    public String createAuthorizationCode(OAuth2Authentication oAuth2Authentication) {
        String clientId = oAuth2Authentication.getOAuth2Request().getClientId();
        return saveCode(clientId, oAuth2Authentication);
    }


    /**
     * Redeems an authorization code and returns the authentication stored with it.
     *
     * @param code the authorization code
     * @return the authentication that was stored when the code was issued
     * @throws InvalidGrantException if no stored code matches {@code code}
     */
    @Override
    public OAuth2Authentication consumeAuthorizationCode(String code) {
        String codeKey = oAuth2Properties.getOauthCodePrefix() + code;
        // NOTE(review): lookup and delete are separate Redis calls, so two concurrent
        // redemptions of the same code could both pass the lookup - confirm whether
        // RedisUtils offers an atomic get-and-delete.
        OauthCode oauthCode = (OauthCode) redisUtils.get(codeKey);
        if (oauthCode == null) {
            throw new InvalidGrantException("授权码无效: " + code);
        }
        // Delete the code before deserializing so it stays single-use even if
        // deserialization fails.
        redisUtils.delete(codeKey);
        OAuth2Authentication authentication = deserializeAuthentication(oauthCode.getAuthentication());
        if (authentication != null) {
            // Drop the pending-code entry of the user this code was issued for.
            redisUtils.delete(pendingCodeKey(oauthCode.getClientId(), authentication.getName()));
        }
        return authentication;
    }


    /**
     * Stores a new authorization code, or returns the still-pending code previously issued to
     * the same user for the same client.
     *
     * @param clientId             id of the client the code is issued to
     * @param oAuth2Authentication the authentication to bind to the code
     * @return the authorization code
     */
    private String saveCode(String clientId, OAuth2Authentication oAuth2Authentication) {
        String pendingKey = pendingCodeKey(clientId, oAuth2Authentication.getName());
        String oauthCodePrefix = oAuth2Properties.getOauthCodePrefix();
        String cacheCode = validatorClientIdCode(clientId, pendingKey, oauthCodePrefix);
        if (cacheCode != null) {
            // The reused code stays bound to the authentication stored when it was first issued.
            return cacheCode;
        }
        /* Build the custom code record. */
        String code = randomValueStringGenerator.generate();
        OauthCode oauthCode = new OauthCode();
        oauthCode.setClientId(clientId);
        oauthCode.setCode(code);
        /* Serialize the authentication with the configured serializer. */
        if (oAuth2Properties.isEnableFastJsonSerializer()) {
            //noinspection unchecked
            oauthCode.setAuthentication(serializations.serialize(oAuth2Authentication));
        } else {
            oauthCode.setAuthentication(SerializationUtils.serialize(oAuth2Authentication));
        }
        /* Remember the pending code for this client and user so it can be reused. */
        redisUtils.set(pendingKey, code);
        redisUtils.expire(pendingKey, CODE_TTL_MINUTES, TimeUnit.MINUTES);
        /* Store the code record with its time-to-live. */
        String codeKey = oauthCodePrefix + code;
        redisUtils.set(codeKey, oauthCode);
        redisUtils.expire(codeKey, CODE_TTL_MINUTES, TimeUnit.MINUTES);
        return code;
    }


    /**
     * Looks up a pending (issued but not yet consumed) code and, if one exists for the given
     * client, renews the time-to-live of both of its Redis entries. Reusing the pending code
     * keeps repeated authorization requests from piling up entries in Redis.
     *
     * @param clientId        id of the client requesting a code
     * @param pendingKey      Redis key of the pending-code entry, see {@link #pendingCodeKey}
     * @param oauthCodePrefix key prefix of the stored code records
     * @return the pending code, or {@code null} if there is none that may be reused
     */
    private String validatorClientIdCode(String clientId, String pendingKey, String oauthCodePrefix) {
        String cacheCode = (String) redisUtils.get(pendingKey);
        if (cacheCode == null) {
            return null;
        }
        String codeKey = oauthCodePrefix + cacheCode;
        OauthCode oauthCode = (OauthCode) redisUtils.get(codeKey);
        if (oauthCode == null) {
            return null;
        }
        // Guards against two different (clientId, userName) pairs concatenating to the same key.
        if (clientId == null || !clientId.equals(oauthCode.getClientId())) {
            return null;
        }
        /* Renew the pending-code entry. */
        redisUtils.expire(pendingKey, CODE_TTL_MINUTES, TimeUnit.MINUTES);
        /* Renew the code record. */
        redisUtils.expire(codeKey, CODE_TTL_MINUTES, TimeUnit.MINUTES);
        return oauthCode.getCode();
    }

    /**
     * Builds the Redis key of the pending-code entry for a client and user.
     *
     * @param clientId id of the client
     * @param userName name of the authentication the code belongs to
     * @return the Redis key
     */
    private String pendingCodeKey(String clientId, String userName) {
        return oAuth2Properties.getCodeClientPrefix() + clientId + ":" + userName;
    }

    /**
     * Deserializes a stored authentication with the configured serializer.
     *
     * @param bytes the serialized authentication taken from the code record
     * @return the deserialized authentication
     */
    private OAuth2Authentication deserializeAuthentication(byte[] bytes) {
        if (oAuth2Properties.isEnableFastJsonSerializer()) {
            return (OAuth2Authentication) serializations.deserialize(bytes, OAuth2Authentication.class);
        }
        return SerializationUtils.deserialize(bytes);
    }


}
